package org.zjvis.datascience.spark.algorithm;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.mllib.clustering.dbscan.DBSCAN;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.Tuple2;
import scala.collection.JavaConversions;

import java.util.*;

/**
 * @description Spark DBSCAN operator (deprecated). Reads the feature columns of a
 * source table, clusters them with DBSCAN and writes the labeled points to a
 * target table.
 * @date 2021-12-23
 */
@Deprecated
public class DBscanAlgorithm extends BaseAlgorithm {
    /** Neighborhood radius handed to {@code DBSCAN.train}. */
    private double eps;

    /** Minimum number of points handed to {@code DBSCAN.train}. */
    private int minPoints;

    /** Upper bound of points per partition handed to {@code DBSCAN.train}. */
    private int maxPointsPerPartition;

    /** Names of the columns assembled into the feature vector. */
    private String[] featureCols;

    /** Table the labeled result is written to. */
    private String targetTable;

    /** Table the input rows are read from. */
    private String sourceTable;

    public DBscanAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    /**
     * Registers the command line options of this operator, parses {@code args}
     * and copies the supplied values into the corresponding fields.
     *
     * @param args raw command line arguments
     * @return {@code false} when the command line cannot be initialised or when
     *         one of featureCols / source / target is missing, {@code true} otherwise
     */
    public boolean parseParams(String[] args) {
        super.parseParams(args);
        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("idCol"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("eps", "eps", "eps", true));
        options.addOption(OptionHelper.getSpecOption("minp", "minPoints", "min points", true));
        options.addOption(OptionHelper.getSpecOption("maxp", "maxPoints", "max points per partition", true));

        boolean flag = this.initCommandLine(args);

        if (!flag) {
            logger.error("initCommandLine fail!!!");
            return false;
        }

        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("idcol") || cmd.hasOption("idCol")) {
            idCol = cmd.getOptionValue("idCol", UtilTool.DEFAULT_ID_COL);
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("eps")) {
            eps = Double.parseDouble(cmd.getOptionValue("eps"));
        }
        if (cmd.hasOption("minp") || cmd.hasOption("minPoints")) {
            minPoints = Integer.parseInt(cmd.getOptionValue("minPoints"));
        }
        if (cmd.hasOption("maxp") || cmd.hasOption("maxPoints")) {
            maxPointsPerPartition = Integer.parseInt(cmd.getOptionValue("maxPoints"));
        }

        // Fail fast: beginAlgorithm() dereferences these values unconditionally, so a
        // missing one would otherwise only surface later as an NPE or a broken query.
        if (featureCols == null || featureCols.length == 0) {
            logger.error("featureCols is required!!!");
            return false;
        }
        if (sourceTable == null || targetTable == null) {
            logger.error("source and target are required!!!");
            return false;
        }
        return true;
    }

    /**
     * Runs the operator: loads the source rows, assembles the feature vector,
     * trains DBSCAN, writes the labeled points to the target table and stores
     * the run result in {@code retResult}.
     *
     * @return {@code true} once the result table and the run result are saved
     */
    public boolean beginAlgorithm() {
        Dataset<Row> dataset = loadSourceDataset();

        Map<String, DataType> schema = UtilTool.getDataTypeMaps(dataset.schema(), true);

        // NOTE(review): presumably records feature columns whose type is temporarily
        // treated as double so it can be restored before saving; confirm in UtilTool.
        Map<String, Tuple2<DataType, DataType>> needChangeTypes = UtilTool.getNeedChangeDoubleTypeKeys(featureCols, schema);

        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");
        Dataset<Row> transDF = assembler.transform(dataset).select(idCol, "features").persist(StorageLevel.MEMORY_AND_DISK());
        RDD<Vector> vectorRDD = null;
        try {
            // VectorAssembler emits a compressed vector that may be sparse, so cast to
            // the ml Vector interface instead of DenseVector to avoid a ClassCastException.
            vectorRDD = transDF.toJavaRDD()
                    .map(row -> Vectors.fromML((org.apache.spark.ml.linalg.Vector) row.get(1)))
                    .rdd()
                    .persist(StorageLevel.MEMORY_AND_DISK());

            DBSCAN dbscan = DBSCAN.train(
                    vectorRDD,
                    eps,
                    minPoints,
                    maxPointsPerPartition
            );
            // Each output row is: the feature values, then the cluster id, then the flag.
            JavaRDD<Row> rdd = dbscan.labeledPoints().toJavaRDD().map(p -> {
                List<Object> features = new ArrayList<>();
                for (double value : p.vector().toArray()) {
                    features.add(value);
                }
                features.add(String.valueOf(p.cluster()));
                features.add(p.flag().toString());
                return RowFactory.create(features.toArray());
            });

            // Insertion-ordered so the extra columns are listed in the same order as the
            // values appended to each row above (cluster id first, flag second).
            Map<String, DataType> extraCols = new LinkedHashMap<>();
            extraCols.put("cluster_id", DataTypes.StringType);
            extraCols.put("flag", DataTypes.StringType);
            StructType newSchema = UtilTool.createSchema(schema, Arrays.asList(featureCols), 0, extraCols);

            Dataset<Row> df = sparkSession.createDataFrame(rdd, newSchema);

            df = UtilTool.changeBackDataTypeForKey(df, needChangeTypes);

            if (isHive) {
                this.saveHiveTable(df, targetTable);
            } else {
                UtilTool.saveGreenplumTable(df, targetTable);
            }

            Map<String, Object> inputKV = new HashMap<>();
            List<String> outputTables = new ArrayList<>();
            outputTables.add(targetTable);

            OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

            if (isHive) {
                String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);
                this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
            } else {
                this.saveOutputResultForMysql(sparkSession, outputResult);
            }
            retResult = outputResult;
            return true;
        } finally {
            // Release the cached data on both the success and the failure path.
            if (vectorRDD != null) {
                vectorRDD.unpersist(false);
            }
            transDF.unpersist(false);
        }
    }

    /**
     * Loads the id column and the feature columns of the source table, from Hive
     * when {@code isHive} is set and through {@code UtilTool.readFromGreenPlum} otherwise.
     */
    private Dataset<Row> loadSourceDataset() {
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            return sparkSession.sql(sql);
        }
        return UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
                .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
    }
}